import os
from input_data import *
from model import *
import matplotlib.pyplot as plt
import configparser


def test(test_dir, logs_dir, img_size):
    """Restore the latest checkpoint and visualise predictions on test images.

    Builds the input pipeline and the "test" inference graph, restores the
    most recent checkpoint found in ``logs_dir`` and runs ``MAX_STEP``
    single-image batches through the network. Each prediction is drawn in a
    2 x 5 grid of subplots; after every 10 images the figure is shown for
    5 seconds and then cleared for the next page.

    Args:
        test_dir: Directory passed to ``get_train_files`` to list the images.
        logs_dir: Directory searched with ``tf.train.get_checkpoint_state``.
        img_size: Image size passed to ``get_train_batch``.

    If no checkpoint is found, a message is printed and the function returns
    without running the network, since its variables would be uninitialised.
    """
    N_CLASS = 2
    IMG_SIZE = img_size
    BATCH_SIZE = 1
    CAPACITY = 200
    MAX_STEP = 100
    LIST_CHANNELS = [3, 16, 32, 128, 128]
    N_ROWS, N_COLS = 2, 5
    N_SHOW = N_ROWS * N_COLS  # images per figure page

    keep_prob = tf.placeholder(tf.float32)

    test_list = get_train_files(test_dir, random=True)
    # The label batch is not needed for visualisation.
    image_batch, _ = get_train_batch(test_list, IMG_SIZE, BATCH_SIZE, CAPACITY, True)
    # NOTE(review): presumably per-class probabilities of shape
    # (BATCH_SIZE, N_CLASS) -- confirm against model.inference.
    softmax = inference(image_batch, N_CLASS, LIST_CHANNELS, "test", keep_prob)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Load the checkpoint.
        print("载入检查点...")
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("没有找到检查点")
            # Without restored weights every sess.run would fail on
            # uninitialised variables, so stop here.
            return
        # os.path.basename splits on the platform's separators (both "/" and
        # "\\" on Windows), unlike a hard-coded split("/").
        global_step = os.path.basename(ckpt.model_checkpoint_path).split("-")[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("载入成功，global_step=%s" % global_step)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        fig = plt.figure(figsize=(16, 8))
        axes = [fig.add_subplot(N_ROWS, N_COLS, i + 1) for i in range(N_SHOW)]
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                image, prediction = sess.run([image_batch, softmax], feed_dict={keep_prob: 1})
                probs = prediction[0]  # BATCH_SIZE is 1: take the only sample
                max_index = int(np.argmax(probs))  # index of the most likely class
                label = "%.2f%% is a Label_%d." % (probs[max_index] * 100, max_index)

                ax = axes[step % N_SHOW]
                ax.set_title(label, fontsize=10, y=1.02)
                ax.imshow(image[0])
                if (step + 1) % N_SHOW == 0:
                    plt.draw()
                    plt.pause(5)  # show the full page for 5 seconds
                    # clf() removes all axes, so recreate them for the next page.
                    plt.clf()
                    if step + 1 != MAX_STEP:
                        axes = [fig.add_subplot(N_ROWS, N_COLS, i + 1) for i in range(N_SHOW)]
        except tf.errors.OutOfRangeError:
            print("Done.")
        finally:
            coord.request_stop()
            # Close the figure on every exit path, including OutOfRangeError.
            plt.close(fig)
        coord.join(threads=threads)


def main():
    """Read IMG_SIZE from config.ini and run ``test`` on the test data set.

    Expects a UTF-8 ``config.ini`` in the working directory with an integer
    ``IMG_SIZE`` option in section ``[section_2]``. Test images are read from
    ``data/test`` and checkpoints from ``logs``.

    Raises:
        FileNotFoundError: if ``config.ini`` could not be read.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips missing files and returns the list of
    # files it actually parsed; fail loudly instead of with NoSectionError.
    if not config.read("config.ini", encoding="utf-8"):
        raise FileNotFoundError("config.ini not found or unreadable")
    IMG_SIZE = config.getint("section_2", "IMG_SIZE")
    # Build the path portably; a hard-coded backslash only works on Windows.
    DATA = os.path.join("data", "test")
    LOGS = "logs"
    test(DATA, LOGS, IMG_SIZE)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
